import xml.etree.ElementTree as ET
import pickle
import string
import os
import shutil
from os import listdir, getcwd
from os.path import join

# (year, split) pairs iterated by the top-level conversion loop. The loop body
# visible in this file does not read either field, so only the number of
# entries affects how many times the conversion runs.
sets=[('2012', 'train')]

# Object labels accepted by convert_annotation(). A label's position in this
# list is the numeric class id written to the label .txt files, so reordering
# or inserting entries changes the ids of every label after that point.
classes = ["001002014", "001003002", "001003010", "001004002", "001004008", "001004010",
           "001005002", "001006005", "007001002", "007002001", "007002007", "007002008",
           "007002010", "007002015", "007002019", "007002020", "007002021", "007004003",
           "007004005", "007004013", "007004014", "007005001", "010004004", "100304001",
           "110401001", "110401002", "130101001", "130102001", "140103001", "140103002",
           "140103003", "140103004", "140106001",
           "001001001", "001001002", "001001006", "001001007", "001001008", "001002001",
           "001002011", "001002019", "001003003", "001003004", "001003006", "001003007",
           "001003009", "001003011", "001003013", "001003020", "001003021", "001003024",
           "001003025", "001003026", "001003030", "001003031", "001004012", "001005001",
           "001005005", "001005006", "001005007", "001005012", "001006001", "001006006",
           "001007003", "007002004", "007002005", "007002011", "007002013", "007002014",
           "007002018", "007004011", "007004012", "100102002", "100303004", "100401004",
           "100403003", "100404001", "100404002", "100601002", "110301001", "110301002",
           "110301003", "110302001", "110302002", "120102004"]  # Class set. NOTE(review): the original comment described a single "cat" class, which no longer matches this list of numeric codes.


def convert(size, box):
    """Normalise a pixel-space box to a centre/extent box relative to the image.

    Args:
        size: ``(width, height)`` of the image in pixels.
        box: ``(xmin, xmax, ymin, ymax)`` edges of the box in pixels.

    Returns:
        Tuple ``(x, y, w, h)``: the box centre and the box width/height, each
        expressed as a fraction of the image width/height and rounded to four
        decimal places.
    """
    # Scale by the reciprocal of each image dimension (computed once per axis).
    scale_x = 1. / size[0]
    scale_y = 1. / size[1]

    left, right, top, bottom = box[0], box[1], box[2], box[3]

    centre_x = (left + right) / 2.0
    centre_y = (top + bottom) / 2.0
    extent_x = right - left
    extent_y = bottom - top

    return (
        round(centre_x * scale_x, 4),
        round(centre_y * scale_y, 4),
        round(extent_x * scale_x, 4),
        round(extent_y * scale_y, 4),
    )

def convert_annotation(image_id,flag,savepath):
    """Convert one XML annotation file into a label text file.

    Reads ``<savepath>/<split>ImageXML/<image_id>.xml`` (Pascal VOC-style
    layout: ``size/width``, ``size/height`` and ``object`` elements holding
    ``name``, ``difficult`` and ``bndbox``) and writes
    ``<savepath>/<split>ImageLabelTxt/<image_id>.txt``, creating the label
    directory when it does not exist yet.  Every kept object becomes one line
    ``<class id> <x> <y> <w> <h>``, where the class id is the label's index in
    the module-level ``classes`` list and the box is normalised by
    ``convert``.

    Objects whose name is not in ``classes``, or whose ``difficult`` flag is
    1, are skipped.  An object without a ``difficult`` element is treated as
    not difficult.

    Args:
        image_id: Base name (no extension) of the annotation/label files.
        flag: ``0`` to use the ``train`` directories, ``1`` to use the
            ``validate`` directories.
        savepath: Root directory that contains the ``*ImageXML`` directories.

    Raises:
        ValueError: If ``flag`` is neither 0 nor 1.
    """
    if flag == 0:
        split = 'train'
    elif flag == 1:
        split = 'validate'
    else:
        # Any other value used to fall through to an unbound ``root`` below.
        raise ValueError('flag must be 0 (train) or 1 (validate), got %r' % (flag,))

    xml_path = '%s/%sImageXML/%s.xml' % (savepath, split, image_id)
    label_dir = '%s/%sImageLabelTxt' % (savepath, split)
    label_path = '%s/%s.txt' % (label_dir, image_id)

    # Passing the path lets ElementTree open *and close* the file itself.
    root = ET.parse(xml_path).getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    # Collect every line first so that a malformed annotation never leaves a
    # truncated or empty label file behind.
    label_lines = []
    for obj in root.iter('object'):
        cls = obj.find('name').text
        if cls not in classes:
            continue
        difficult_node = obj.find('difficult')
        if difficult_node is not None and int(difficult_node.text) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        # convert() expects the edges ordered (xmin, xmax, ymin, ymax).
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
             float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        print(b)
        bb = convert((w, h), b)
        label_lines.append(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

    if not os.path.exists(label_dir):
        os.mkdir(label_dir)
    with open(label_path, 'w') as out_file:
        out_file.writelines(label_lines)

wd = getcwd()

# UTF-8 byte-order mark as it appears when the id file is read as raw bytes
# (or decoded byte-for-byte, e.g. as Latin-1).
# NOTE(review): when the file is decoded as UTF-8 the BOM becomes the single
# character U+FEFF, which this marker does not match - confirm which
# interpreter/encoding this script runs under.
_BOM = '\xef\xbb\xbf'


def _convert_split(split, flag, savepath):
    """Write ``<split>ImagePath.txt`` and convert every listed annotation.

    Reads whitespace-separated image ids from ``<savepath>/<split>ImageId.txt``;
    for each id, appends the image's .jpg path (rooted at ``wd``) to
    ``<savepath>/<split>ImagePath.txt``, prints the id and calls
    ``convert_annotation(image_id, flag, savepath)``.
    """
    with open('%s/%sImageId.txt' % (savepath, split)) as id_file:
        image_ids = id_file.read().strip().split()

    # The context manager closes the path list even if a conversion raises.
    with open('%s/%sImagePath.txt' % (savepath, split), 'w') as list_file:
        for image_id in image_ids:
            # Only strip the marker when it really is a prefix; slicing three
            # characters off an id that merely contains it would corrupt the id.
            if image_id.startswith(_BOM):
                image_id = image_id[len(_BOM):]
            list_file.write('%s/%sImage/%s.jpg\n' % (wd, split, image_id))
            print(image_id)
            convert_annotation(image_id, flag, savepath)


# The (year, image_set) values are not used; the conversion simply runs once
# per entry in ``sets``, validation split first and then the training split.
for year, image_set in sets:
    savepath = os.getcwd()
    _convert_split('validate', 1, savepath)
    _convert_split('train', 0, savepath)


